{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tune T5 Model with PyTorchJob and FSDP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This Notebook will fine-tune Text-to-Text Transfer Transformer (T5) with Wikihow dataset for text summarization using Kubeflow PyTorchJob.\n",
    "\n",
    "Pretrained T5 model: https://huggingface.co/google-t5/t5-base\n",
    "\n",
    "Wikihow dataset: https://github.com/mahnazkoupaee/WikiHow-Dataset\n",
    "\n",
    "This Notebook will use **4** GPUs to fine-tune T5 model on 2 Nodes. This example is based on [the official PyTorch FSDP tutorial](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FSDP with multi-node multi-worker training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This Notebook demonstrates multi-node, multi-worker distributed training with Fully Sharded Data Parallel (FSDP) and PyTorchJob.\n",
    "\n",
    "When a model is trained with FSDP, the GPU memory footprint is smaller compare to Distributed Data Parallel (DDP),\n",
    "as the model parameters are sharded across GPU devices.\n",
    "\n",
    "This enables training of very large models that would otherwise be impossible to fit on a single GPU device.\n",
    "\n",
    "Check this guide to learn more about PyTorch FSDP: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Install the required packages\n",
    "\n",
    "Install the Kubeflow Training Python SDK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# TODO (andreyvelich): Use the release version of SDK.\n",
    "!pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create script to fine-tune T5 using FSDP\n",
    "\n",
    "We need to wrap our fine-tuning script in a function to create Kubeflow PyTorchJob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train_function(parameters):\n",
    "    import os\n",
    "    import time\n",
    "    import functools\n",
    "\n",
    "    import torch\n",
    "    import torch.distributed as dist\n",
    "    from torch.utils.data.distributed import DistributedSampler\n",
    "    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP\n",
    "    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\n",
    "\n",
    "    from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "    from transformers.models.t5.modeling_t5 import T5Block\n",
    "    from datasets import Dataset\n",
    "\n",
    "    # [1] Setup PyTorch distributed and get the distributed parameters.\n",
    "    dist.init_process_group(\"nccl\")\n",
    "    local_rank = int(os.environ[\"LOCAL_RANK\"])\n",
    "    rank = dist.get_rank()\n",
    "    world_size = dist.get_world_size()\n",
    "\n",
    "    # Local rank identifies the GPU number inside the pod.\n",
    "    torch.cuda.set_device(local_rank)\n",
    "\n",
    "    print(\n",
    "        f\"FSDP Training for WORLD_SIZE: {world_size}, RANK: {rank}, LOCAL_RANK: {local_rank}\"\n",
    "    )\n",
    "\n",
    "    # [2] Prepare the Wikihow dataset\n",
    "    class wikihow(torch.utils.data.Dataset):\n",
    "        def __init__(\n",
    "            self,\n",
    "            tokenizer,\n",
    "            num_samples,\n",
    "            input_length,\n",
    "            output_length,\n",
    "        ):\n",
    "\n",
    "            self.dataset = Dataset.from_csv(parameters[\"DATASET_URL\"])\n",
    "            self.dataset = self.dataset.select(list(range(0, num_samples)))\n",
    "            self.input_length = input_length\n",
    "            self.tokenizer = tokenizer\n",
    "            self.output_length = output_length\n",
    "\n",
    "        def __len__(self):\n",
    "            return self.dataset.shape[0]\n",
    "\n",
    "        def clean_text(self, text):\n",
    "            # Dataset contains empty values.\n",
    "            if text is None:\n",
    "                return \"\"\n",
    "            text = text.replace(\"Example of text:\", \"\")\n",
    "            text = text.replace(\"Example of Summary:\", \"\")\n",
    "            text = text.replace(\"\\n\", \"\")\n",
    "            text = text.replace(\"``\", \"\")\n",
    "            text = text.replace('\"', \"\")\n",
    "\n",
    "            return text\n",
    "\n",
    "        def convert_to_features(self, example_batch):\n",
    "            # Tokenize text and headline (as pairs of inputs).\n",
    "            input_ = self.clean_text(example_batch[\"text\"])\n",
    "            target_ = self.clean_text(example_batch[\"headline\"])\n",
    "\n",
    "            source = self.tokenizer.batch_encode_plus(\n",
    "                [input_],\n",
    "                max_length=self.input_length,\n",
    "                padding=\"max_length\",\n",
    "                truncation=True,\n",
    "                return_tensors=\"pt\",\n",
    "            )\n",
    "\n",
    "            targets = self.tokenizer.batch_encode_plus(\n",
    "                [target_],\n",
    "                max_length=self.output_length,\n",
    "                padding=\"max_length\",\n",
    "                truncation=True,\n",
    "                return_tensors=\"pt\",\n",
    "            )\n",
    "\n",
    "            return source, targets\n",
    "\n",
    "        def __getitem__(self, index):\n",
    "            source, targets = self.convert_to_features(self.dataset[index])\n",
    "\n",
    "            source_ids = source[\"input_ids\"].squeeze()\n",
    "            target_ids = targets[\"input_ids\"].squeeze()\n",
    "\n",
    "            src_mask = source[\"attention_mask\"].squeeze()\n",
    "            target_mask = targets[\"attention_mask\"].squeeze()\n",
    "\n",
    "            return {\n",
    "                \"source_ids\": source_ids,\n",
    "                \"source_mask\": src_mask,\n",
    "                \"target_ids\": target_ids,\n",
    "                \"target_mask\": target_mask,\n",
    "            }\n",
    "\n",
    "    # [3] Get the T5 pre-trained model and tokenizer.\n",
    "    # Since this script is run by multiple workers, we should print results only for the worker with RANK=0.\n",
    "    if rank == 0:\n",
    "        print(f\"Downloading the {parameters['MODEL_NAME']} model\")\n",
    "\n",
    "    model = T5ForConditionalGeneration.from_pretrained(parameters[\"MODEL_NAME\"])\n",
    "    tokenizer = T5Tokenizer.from_pretrained(parameters[\"MODEL_NAME\"])\n",
    "\n",
    "    # [4] Download the Wikihow dataset.\n",
    "    if rank == 0:\n",
    "        print(\"Downloading the Wikihow dataset\")\n",
    "\n",
    "    dataset = wikihow(tokenizer, 1500, 512, 150)\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        dataset,\n",
    "        batch_size=4,\n",
    "        sampler=DistributedSampler(dataset),\n",
    "    )\n",
    "\n",
    "    # [5] Setup model with FSDP.\n",
    "    # Model is on CPU before input to FSDP.\n",
    "    t5_auto_wrap_policy = functools.partial(\n",
    "        transformer_auto_wrap_policy,\n",
    "        transformer_layer_cls={\n",
    "            T5Block,\n",
    "        },\n",
    "    )\n",
    "    model = FSDP(\n",
    "        model,\n",
    "        auto_wrap_policy=t5_auto_wrap_policy,\n",
    "        device_id=torch.cuda.current_device(),\n",
    "    )\n",
    "\n",
    "    # [6] Start training.\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=0.002)\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)\n",
    "    t0 = time.time()\n",
    "    if rank == 0:\n",
    "        print(\"Training is started...\")\n",
    "\n",
    "    for epoch in range(1, 3):\n",
    "        model.train()\n",
    "        fsdp_loss = torch.zeros(2).to(local_rank)\n",
    "\n",
    "        for batch in train_loader:\n",
    "            for key in batch.keys():\n",
    "                batch[key] = batch[key].to(local_rank)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            output = model(\n",
    "                input_ids=batch[\"source_ids\"],\n",
    "                attention_mask=batch[\"source_mask\"],\n",
    "                labels=batch[\"target_ids\"],\n",
    "            )\n",
    "            loss = output[\"loss\"]\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            fsdp_loss[0] += loss.item()\n",
    "            fsdp_loss[1] += len(batch)\n",
    "\n",
    "        dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)\n",
    "        train_accuracy = fsdp_loss[0] / fsdp_loss[1]\n",
    "\n",
    "        if rank == 0:\n",
    "            print(f\"Train Epoch: \\t{epoch}, Loss: \\t{train_accuracy:.4f}\")\n",
    "\n",
    "        scheduler.step()\n",
    "\n",
    "    dist.barrier()\n",
    "\n",
    "    if rank == 0:\n",
    "        print(f\"FSDP training time: {int(time.time() - t0)} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Kubeflow PyTorchJob to fine-tune T5 with FSDP\n",
    "\n",
    "Use `TrainingClient()` to create PyTorchJob which will fine-tine T5 on **2 workers** using **2 GPU** for each worker.\n",
    "\n",
    "If you don't have enough GPU resources, you can decrease number of workers or number of GPUs per worker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from kubeflow.training import TrainingClient\n",
    "\n",
    "job_name = \"fsdp-fine-tuning\"\n",
    "\n",
    "parameters = {\n",
    "    \"DATASET_URL\": \"https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowAll.csv\",\n",
    "    \"MODEL_NAME\": \"t5-base\",\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create the PyTorchJob.\n",
    "TrainingClient().create_job(\n",
    "    name=job_name,\n",
    "    train_func=train_function,\n",
    "    parameters=parameters,\n",
    "    num_workers=2, # You can modify number of workers or number of GPUs.\n",
    "    num_procs_per_worker=2,\n",
    "    resources_per_worker={\"gpu\": 2},\n",
    "    packages_to_install=[\n",
    "        \"transformers==4.38.2\",\n",
    "        \"datasets==2.21.0\",\n",
    "        \"SentencePiece==0.2.0\",\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Check the PyTorchJob conditions\n",
    "\n",
    "Use `TrainingClient()` APIs to get information about created PyTorchJob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorchJob Conditions\n",
      "[{'last_transition_time': datetime.datetime(2024, 10, 16, 19, 24, 17, tzinfo=tzutc()),\n",
      " 'last_update_time': datetime.datetime(2024, 10, 16, 19, 24, 17, tzinfo=tzutc()),\n",
      " 'message': 'PyTorchJob fsdp-fine-tuning is created.',\n",
      " 'reason': 'PyTorchJobCreated',\n",
      " 'status': 'True',\n",
      " 'type': 'Created'}, {'last_transition_time': datetime.datetime(2024, 10, 16, 19, 24, 18, tzinfo=tzutc()),\n",
      " 'last_update_time': datetime.datetime(2024, 10, 16, 19, 24, 18, tzinfo=tzutc()),\n",
      " 'message': 'PyTorchJob fsdp-fine-tuning is running.',\n",
      " 'reason': 'PyTorchJobRunning',\n",
      " 'status': 'True',\n",
      " 'type': 'Running'}]\n",
      "----------------------------------------\n",
      "PyTorchJob is running\n"
     ]
    }
   ],
   "source": [
    "print(\"PyTorchJob Conditions\")\n",
    "print(TrainingClient().get_job_conditions(job_name))\n",
    "print(\"-\" * 40)\n",
    "\n",
    "# Wait until PyTorchJob has the Running condition.\n",
    "job = TrainingClient().wait_for_job_conditions(\n",
    "    job_name,\n",
    "    expected_conditions={\"Running\"},\n",
    ")\n",
    "print(\"PyTorchJob is running\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the PyTorchJob pod names\n",
    "\n",
    "Since we define 2 workers, PyTorchJob will create 1 master pod and 1 worker pod to run FSDP fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['fsdp-fine-tuning-master-0', 'fsdp-fine-tuning-worker-0']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TrainingClient().get_job_pod_names(job_name)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.status.busy": "2022-09-01T20:10:25.759950Z",
     "iopub.status.idle": "2022-09-01T20:10:25.760581Z",
     "shell.execute_reply": "2022-09-01T20:10:25.760353Z",
     "shell.execute_reply.started": "2022-09-01T20:10:25.760328Z"
    },
    "tags": []
   },
   "source": [
    "### Get the PyTorchJob training logs\n",
    "\n",
    "Model parameters are sharded across all workers and GPU devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Pod fsdp-fine-tuning-master-0]: WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n",
      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] \n",
      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] *****************************************\n",
      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
      "[Pod fsdp-fine-tuning-master-0]: [2024-10-16 19:24:47,178] torch.distributed.run: [WARNING] *****************************************\n",
      "[Pod fsdp-fine-tuning-master-0]: FSDP Training for WORLD_SIZE: 4, RANK: 0, LOCAL_RANK: 0\n",
      "[Pod fsdp-fine-tuning-master-0]: Downloading the t5-base model\n",
      "[Pod fsdp-fine-tuning-master-0]: /opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "[Pod fsdp-fine-tuning-master-0]:   warnings.warn(\n",
      "[Pod fsdp-fine-tuning-master-0]: FSDP Training for WORLD_SIZE: 4, RANK: 1, LOCAL_RANK: 1\n",
      "[Pod fsdp-fine-tuning-master-0]: /opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "[Pod fsdp-fine-tuning-master-0]:   warnings.warn(\n",
      "[Pod fsdp-fine-tuning-master-0]: You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "[Pod fsdp-fine-tuning-master-0]: You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "[Pod fsdp-fine-tuning-master-0]: Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "[Pod fsdp-fine-tuning-master-0]: Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "[Pod fsdp-fine-tuning-master-0]: Downloading the Wikihow dataset\n",
      "Downloading data: 100%|██████████| 619M/619M [00:11<00:00, 55.0MB/s] \n",
      "Generating train split: 215365 examples [00:08, 26087.53 examples/s]\n",
      "[Pod fsdp-fine-tuning-master-0]: Training is started...\n",
      "[Pod fsdp-fine-tuning-master-0]: Train Epoch: \t1, Loss: \t0.3802\n",
      "[Pod fsdp-fine-tuning-master-0]: Train Epoch: \t2, Loss: \t0.2659\n",
      "[Pod fsdp-fine-tuning-master-0]: FSDP training time: 107 seconds\n"
     ]
    }
   ],
   "source": [
    "logs, _ = TrainingClient().get_job_logs(job_name, follow=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-03-01T23:44:15.511173Z",
     "iopub.status.busy": "2024-03-01T23:44:15.510932Z",
     "iopub.status.idle": "2024-03-01T23:44:15.539921Z",
     "shell.execute_reply": "2024-03-01T23:44:15.539352Z",
     "shell.execute_reply.started": "2024-03-01T23:44:15.511155Z"
    },
    "tags": []
   },
   "source": [
    "## Delete the PyTorchJob\n",
    "\n",
    "You can delete the created PyTorchJob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "TrainingClient().delete_job(name=job_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
